# Source code for hysop.symbolic.spectral
# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sympy as sm
import numpy as np
from hysop.constants import BoundaryCondition, BoundaryExtension, TransformType
from hysop.tools.htypes import check_instance, to_tuple, first_not_None
from hysop.tools.sympy_utils import Expr, Symbol, Dummy, subscript
from hysop.tools.spectral_utils import SpectralTransformUtils as STU
from hysop.symbolic import SpaceSymbol
from hysop.symbolic.array import SymbolicBuffer
from hysop.symbolic.field import (
FieldExpressionBuilder,
FieldExpressionI,
TensorBase,
SymbolicField,
AppliedSymbolicField,
)
from hysop.symbolic.frame import SymbolicFrame
from hysop.fields.continuous_field import Field, ScalarField, TensorField
from hysop.tools.spectral_utils import SpectralTransformUtils
class WaveNumberIndex(sm.Symbol):
    """Symbolic index ``i<axis>`` attached to one spatial axis.

    After :meth:`bind_axes` has been called, :attr:`real_index` yields the
    matching entry of ``hysop.symbolic.local_indices_symbols``.
    """

    def __new__(cls, axis):
        index = super().__new__(cls, f"i{axis}")
        index.axis = axis
        # Both stay unset until bind_axes() is called.
        index._axes, index._real_index = None, None
        return index

    def __init__(self, axis):
        super().__init__()

    def bind_axes(self, axes):
        """Bind this index to ``axes`` and resolve its local index symbol.

        Rebinding is only accepted with equal ``axes`` (checked by assert).
        The local index is taken in reversed order: position
        ``len(axes) - 1 - axes.index(self.axis)``.
        """
        assert (self._axes is None) or (axes == self._axes)
        ndim = len(axes)
        from hysop.symbolic import local_indices_symbols

        self._axes = axes
        position = ndim - 1 - axes.index(self.axis)
        self._real_index = local_indices_symbols[position]

    @property
    def real_index(self):
        """Local index symbol resolved by :meth:`bind_axes`.

        Raises RuntimeError when no axes have been bound yet.
        """
        if self._real_index is not None:
            return self._real_index
        raise RuntimeError("No axes bound yet !")
class WaveNumber(Dummy):
    """Wave number symbol for SpectralTransform derivatives (and integrals).

    Instances are cached per ``(transform, axis, exponent)``: requesting the
    same wave number twice returns the same object.

    Parameters
    ----------
    axis: int
        Non-negative axis index of the wave number.
    transform: TransformType
        Spectral transform applied along ``axis``.
    exponent: int
        Power of the wave number. Negative powers get an ``i`` prefix in the
        symbol name.

    Special cases: ``TransformType.NONE`` yields ``None``, and a zero exponent
    yields the integer ``1`` instead of a WaveNumber instance.
    """

    # Short transform tag used to build symbol names (inverse transforms map
    # to the tag of their matching forward transform).
    __transform2str = {
        TransformType.FFT: "c2c",
        TransformType.RFFT: "r2c",
        TransformType.DCT_I: "c1",
        TransformType.DCT_II: "c2",
        TransformType.DCT_III: "c3",
        TransformType.DCT_IV: "c4",
        TransformType.DST_I: "s1",
        TransformType.DST_II: "s2",
        TransformType.DST_III: "s3",
        TransformType.DST_IV: "s4",
        TransformType.IFFT: "c2c",
        TransformType.IRFFT: "r2c",
        TransformType.IDCT_I: "c1",
        TransformType.IDCT_II: "c3",
        TransformType.IDCT_III: "c2",
        TransformType.IDCT_IV: "c4",
        TransformType.IDST_I: "s1",
        TransformType.IDST_II: "s3",
        TransformType.IDST_III: "s2",
        TransformType.IDST_IV: "s4",
    }

    # Cache of already built wave numbers, keyed by (transform, axis, exponent).
    __wave_numbers = {}

    def __new__(cls, axis, transform, exponent, **kwds):
        check_instance(transform, TransformType)
        check_instance(axis, int, minval=0)
        # Zero and negative exponents have dedicated handling below (and can
        # be produced by pow()), so only the type is validated here.
        check_instance(exponent, int)
        if transform is TransformType.NONE:
            return None
        if exponent == 0:
            return 1
        key = (transform, axis, exponent)
        if key in cls.__wave_numbers:
            return cls.__wave_numbers[key]

        tr_str = cls.__transform2str[transform]
        if len(tr_str) == 2:
            # e.g. "c2" -> "c" followed by a subscript 2
            tr_pstr = tr_str[0] + subscript(int(tr_str[1]))
        else:
            tr_pstr = tr_str
        name = f"k{axis}_{tr_str}"
        pretty_name = "k" + subscript(axis) + "_" + tr_pstr
        if exponent < 0:
            # Inverse powers are tagged with an "i" prefix.
            name = "i" + name
            pretty_name = "i" + pretty_name
        magnitude = abs(exponent)
        if magnitude > 1:
            name += f"__{magnitude}"
            pretty_name += f"__{magnitude}"

        obj = super().__new__(cls, name=name, pretty_name=pretty_name, **kwds)
        obj._axis = int(axis)
        obj._transform = transform
        # Keep the signed exponent: dropping the sign would make k**-n equal
        # (and hash-equal) to k**n although they live under distinct cache keys.
        obj._exponent = int(exponent)
        cls.__wave_numbers[key] = obj
        return obj

    def __init__(self, axis, transform, exponent, **kwds):
        super().__init__(name=None, pretty_name=None, **kwds)

    @property
    def axis(self):
        return self._axis

    @property
    def transform(self):
        return self._transform

    @property
    def exponent(self):
        return self._exponent

    @property
    def is_real(self):
        """True for R2R transforms, or for an even exponent otherwise."""
        is_r2r = STU.is_R2R(self._transform)
        return is_r2r or (self._exponent % 2 == 0)

    @property
    def is_complex(self):
        """True for non-R2R transforms with an odd exponent."""
        return (not STU.is_R2R(self._transform)) and (self._exponent % 2 != 0)

    def pow(self, exponent):
        """Return this wave number raised to ``exponent``.

        The resulting exponent is ``exponent * self.exponent``.
        """
        exponent *= self.exponent
        return WaveNumber(axis=self.axis, transform=self.transform, exponent=exponent)

    def indexed_buffer(self, name=None):
        """Return a SymbolicBuffer (named ``name``, default ``self.name``)
        indexed by the WaveNumberIndex of this axis; the result keeps a
        back-reference to this wave number in its ``Wn`` attribute."""
        name = first_not_None(name, self.name)
        buf = SymbolicBuffer(name=name, memory_object=None)
        idx = WaveNumberIndex(self.axis)
        obj = buf[idx]
        obj.Wn = self
        return obj

    def __eq__(self, other):
        if not isinstance(other, WaveNumber):
            return NotImplemented
        return (self.axis, self.transform, self.exponent) == (
            other.axis,
            other.transform,
            other.exponent,
        )

    def __hash__(self):
        return hash((self.axis, self.transform, self.exponent))
class AppliedSpectralTransform(AppliedSymbolicField):
    """
    An applied spectral transform.

    Instances are produced by calling a SpectralTransform on its variables;
    the attributes read here (``_field``, ``_transformed_axes``, ...) are set
    by SpectralTransform.__new__.
    """

    def short_description(self):
        """One-line summary: field, transformed axes, direction and transforms."""
        ss = "{}(field={}, axes={}, is_forward={}, transforms=[{}])"
        return ss.format(
            self.__class__.__name__,
            self.field.pretty_name,
            self.transformed_axes,
            "1" if self.is_forward else "0",
            self.format_transforms(),
        )

    def long_description(self):
        """Multi-line description listing all transform attributes."""
        ss = """
== {} ==
  *field: {}
  *transformed_axes: {}
  *spatial_axes: {}
  *is_forward: {}
  *transforms: {}
  *freq_vars: {}
  *space_vars: {}
  *all_vars: {}
  *wave_numbers: {}
"""
        # Arguments follow the order of the labels in the template above.
        return ss.format(
            self.__class__.__name__,
            self.field.short_description(),
            self.transformed_axes,
            self.spatial_axes,
            self.is_forward,
            self.transforms,
            self.freq_vars,
            self.space_vars,
            self.all_vars,
            self.wave_numbers,
        )

    def format_transforms(self):
        """Join the string form of every transform with ' x '."""
        return " x ".join(str(tr) for tr in self.transforms)

    @property
    def field(self):
        return self._field

    @property
    def transformed_axes(self):
        return self._transformed_axes

    @property
    def spatial_axes(self):
        return self._spatial_axes

    @property
    def freq_vars(self):
        return self._freq_vars

    @property
    def space_vars(self):
        return self._space_vars

    @property
    def all_vars(self):
        return self._all_vars

    @property
    def frame(self):
        return self._frame

    @property
    def lboundaries(self):
        return self._field.lboundaries

    @property
    def rboundaries(self):
        return self._field.rboundaries

    @property
    def domain(self):
        return self._field.domain

    @property
    def dtype(self):
        return self._field.dtype

    @property
    def transforms(self):
        return self._transforms

    @property
    def wave_numbers(self):
        return self._wave_numbers

    @property
    def is_forward(self):
        return self._is_forward

    # SYMPY INTERNALS ################
    @property
    def is_number(self):
        return False

    @property
    def free_symbols(self):
        return set(self._all_vars)

    def _eval_derivative(self, v):
        """d/dv of a frequency variable multiplies by the matching wave number;
        any other variable gives an unevaluated sympy Derivative."""
        if v in self._freq_vars:
            # _wave_numbers is stored in the same (reversed-axis) order as
            # _all_vars, so the position in _all_vars selects the wave number.
            i = self._all_vars.index(v)
            return self._wave_numbers[i] * self
        return sm.Derivative(self, v)

    def _hashable_content(self):
        """See sympy.core.basic.Basic._hashable_content()"""
        hc = super()._hashable_content()
        hc += (self.__class__,)
        return hc

    def __hash__(self):
        # Mix in the class hash (the previous code xored h with hash(h), which
        # is h itself for ints, so every instance hashed to 0).
        return super().__hash__() ^ hash(self.__class__)

    def __eq__(self, other):
        "Fix sympy v1.2 eq"
        eq = super().__eq__(other)
        if eq is not True:
            return eq
        eq &= self.__class__ is other.__class__
        return eq

    ###################################
class SpectralTransform(SymbolicField):
    """
    A single spectral transform that may be applied.
    This object can also be used as a sympy expression (and a FieldExpression).
    This expression carries datatype and boundary conditions.
    """

    def __new__(cls, field, axes=None, forward=True):
        """Build the transform of ``field`` along ``axes`` and apply it.

        Parameters
        ----------
        field: ScalarField or TensorField
            Field to transform. A TensorField is transformed component-wise
            and a TensorBase of applied transforms is returned.
        axes: int or iterable of int, optional
            Axes to transform, defaults to all axes of the field.
        forward: bool
            If False, the inverse transforms are used.

        Returns the applied transform (called on all its variables), not the
        SpectralTransform object itself.
        """
        if isinstance(field, TensorField):
            # Transform every scalar component independently.
            T = field.new_empty_array()
            for idx, f in field.nd_iter():
                T[idx] = cls(field=f, axes=axes, forward=forward)
            T = T.view(TensorBase)
            # NOTE(review): T[0] is a single component only for rank-1
            # tensors; confirm .frame is available on T[0] for higher ranks.
            T.frame = T[0].frame
            return T

        dim = field.dim
        check_instance(field, ScalarField)
        axes = to_tuple(first_not_None(axes, range(dim)))
        check_instance(axes, tuple, values=int, minval=0, maxval=dim - 1, minsize=1)

        transformed_axes = tuple(sorted(set(axes)))
        spatial_axes = tuple(sorted(set(range(dim)) - set(axes)))

        frame = field.domain.frame
        # Frame symbols are stored in reversed axis order (index dim-1-axis).
        freq_vars = tuple(frame.freqs[dim - 1 - i] for i in transformed_axes[::-1])
        space_vars = tuple(frame.coords[dim - 1 - i] for i in spatial_axes[::-1])
        all_vars = ()
        for i in range(dim):
            if i in transformed_axes:
                all_vars += (frame.freqs[dim - 1 - i],)
            else:
                all_vars += (frame.coords[dim - 1 - i],)
        all_vars = all_vars[::-1]

        transforms = SpectralTransformUtils.transforms_from_field(
            field, transformed_axes=transformed_axes
        )
        for i in range(frame.dim):
            assert (transforms[i] is TransformType.NONE) ^ (i in transformed_axes)
        # Wave numbers are reversed to follow the ordering of all_vars; they are
        # generated from the forward transforms, before any inversion.
        wave_numbers = cls.generate_wave_numbers(transforms)[::-1]

        if not forward:
            transforms = SpectralTransformUtils.get_inverse_transforms(*transforms)

        frame = SymbolicFrame(dim=dim, freq_axes=transformed_axes)
        assert frame.coords == all_vars

        obj = super().__new__(cls, field=field, bases=(AppliedSpectralTransform,))
        obj._field = field
        obj._transformed_axes = transformed_axes
        obj._spatial_axes = spatial_axes
        obj._freq_vars = freq_vars
        obj._space_vars = space_vars
        obj._is_forward = forward
        obj._all_vars = all_vars
        obj._transforms = transforms
        obj._wave_numbers = wave_numbers
        obj._frame = frame
        return obj(*all_vars)

    @classmethod
    def generate_wave_numbers(cls, transforms):
        """Delegate to SpectralTransformUtils.generate_wave_numbers(*transforms)."""
        return SpectralTransformUtils.generate_wave_numbers(*transforms)

    def _hashable_content(self):
        """See sympy.core.basic.Basic._hashable_content()"""
        hc = super()._hashable_content()
        hc += (self._transformed_axes, self._is_forward)
        return hc

    def __hash__(self):
        "Fix sympy v1.2 hashes"
        h = super().__hash__()
        for hc in (self._transformed_axes, self._is_forward):
            h ^= hash(hc)
        return h

    def __eq__(self, other):
        "Fix sympy v1.2 eq"
        eq = super().__eq__(other)
        if eq is not True:
            return eq
        return (self._transformed_axes, self._is_forward) == (
            other._transformed_axes,
            other._is_forward,
        )
if __name__ == "__main__":
    from hysop.tools.sympy_utils import sstr, Greak
    from hysop import Box
    from hysop.constants import BoxBoundaryCondition
    from hysop.defaults import VelocityField, VorticityField
    from hysop.symbolic.field import laplacian, curl
    from hysop.symbolic.relational import Assignment

    # Small demo on a 3D box with mixed boundary conditions.
    SYM = BoxBoundaryCondition.SYMMETRIC
    OUT = BoxBoundaryCondition.OUTFLOW
    dim = 3
    d = Box(dim=dim, lboundaries=(SYM, OUT, SYM), rboundaries=(SYM, OUT, OUT))

    U = VelocityField(domain=d)
    W = VorticityField(velocity=U)
    psi = W.field_like(name="psi", pretty_name=Greak[23])

    W_hat = SpectralTransform(W, forward=True)
    U_hat = SpectralTransform(U, forward=False)
    psi_hat = SpectralTransform(psi)

    # Solve laplacian(psi_hat) - W_hat = 0 for the components of psi_hat,
    # then substitute that solution into curl(psi_hat).
    frame = psi_hat.frame
    solution = sm.solve(laplacian(psi_hat, frame) - W_hat, psi_hat.tolist())
    sol = curl(psi_hat, frame).xreplace(solution)

    sections = (
        ("VELOCITY", U.short_description()),
        ("VORTICITY", W.short_description()),
        ("W_hat", W_hat),
        ("U_hat", U_hat),
        ("Psi_hat", psi_hat),
    )
    for title, content in sections:
        print(title)
        print(content)
        print()

    for assignment in Assignment.assign(U_hat, sol):
        parsed, found_transforms, found_wave_numbers = (
            SpectralTransformUtils.parse_expression(assignment)
        )
        print()
        print(parsed)
        for transform in found_transforms:
            print(transform.short_description())
        print(found_wave_numbers)